import numpy as np
from math import *

def single_expert_dynamics(state,action):
  """Return the successor cell of a single expert on the grid.

  The grid spans x in 0..8 and y in 0..12.  A move that would push the
  expert through a wall it is currently touching leaves the state unchanged.

  Parameters
  ----------
  state : indexable
    ``state[0]`` and ``state[1]`` must expose ``.item()`` and hold the x and
    y coordinates (e.g. a 2x1 ``np.matrix`` or ndarray column).
  action : int or one-element array/matrix
    0 stay, 1 y+1, 2 y-1, 3 x-1, 4 x+1, 5 (x-1, y+1), 6 (x+1, y+1).

  Returns
  -------
  np.matrix
    2x1 column ``[[x'], [y']]`` with the next cell.

  Raises
  ------
  ValueError
    If ``action`` is not one of the seven known actions.
  """
  x=state[0].item()
  y=state[1].item()

  # (dx, dy) displacement of every action.
  moves={
    0:(0,0),
    1:(0,1),
    2:(0,-1),
    3:(-1,0),
    4:(1,0),
    5:(-1,1),
    6:(1,1),
  }
  # Accept plain ints as well as the 1x1 matrices used by the callers.
  action_id=np.asarray(action).item()
  if action_id not in moves:
    raise ValueError("unknown action: {!r}".format(action_id))
  dx,dy=moves[action_id]

  # A move is cancelled entirely when any of its components points into a
  # wall the expert is standing next to (diagonals do not slide along walls).
  blocked=(
    (dx<0 and x==0) or (dx>0 and x==8)
    or (dy<0 and y==0) or (dy>0 and y==12)
  )
  if blocked:
    dx,dy=0,0

  # np.mat was removed in NumPy 2.0; np.asmatrix yields the same np.matrix.
  return np.asmatrix([x+dx,y+dy]).T

def choose_action(policy_distribution):  # distribution is 7x1
  """Sample an action index from a discrete probability distribution.

  A single uniform draw in [0, 1) is compared against the running
  cumulative sum of the probabilities (inverse-CDF sampling).

  Parameters
  ----------
  policy_distribution : sequence of float
    Probability of each action; its length defines the number of actions.

  Returns
  -------
  int
    Index of the sampled action.  If rounding leaves the total probability
    slightly below the draw, the last index is returned instead of ``None``.

  Raises
  ------
  ValueError
    If the distribution is empty.
  """
  # Use the distribution's own length rather than a module-level global.
  num_choices=len(policy_distribution)
  if num_choices==0:
    raise ValueError("policy_distribution must contain at least one probability")
  choice=np.random.uniform()
  cumulative=0.0
  for a in range(num_choices):
    cumulative=cumulative+policy_distribution[a]
    if cumulative>=choice:
      return a
  # Floating-point shortfall: the probabilities summed to just under `choice`.
  return num_choices-1

def soft_policy(Q_matrix,V_matrix,num_action):
  """Build the soft (Boltzmann) policy pi(a|x,y) = exp(Q[x][y][a] - V[x][y]).

  Parameters
  ----------
  Q_matrix : array indexed as ``Q_matrix[x][y][a]``
    Soft action values.
  V_matrix : array indexed as ``V_matrix[x][y]``
    Soft state values; its first two dimensions define the grid size.
  num_action : int
    Number of actions per cell.

  Returns
  -------
  np.ndarray
    Object-dtype array of shape (rows, cols, num_action) with the action
    probabilities of every cell.
  """
  rows,cols=np.shape(V_matrix)[:2]
  # np.object was removed in NumPy 1.24; the builtin `object` is equivalent.
  distribution=np.zeros((rows,cols,num_action)).astype(object)
  for x in range(rows):
    for y in range(cols):
      for a in range(num_action):
        # exp(Q - V) equals exp(Q)/exp(V) but cannot overflow when both
        # values are large.
        distribution[x][y][a]=exp(Q_matrix[x][y][a]-V_matrix[x][y])
  return distribution

def soft_Q_matrix_function(gamma,reward_matrix,lam,cost_matrix,V_matrix,num_action):
  """Compute soft action values for every cell of the 9x13 grid.

  ``Q[x][y][a] = reward[x,y,a] - lam*cost[x,y] + gamma*V[next(x,y,a)]``
  where the successor cell comes from ``single_expert_dynamics``.

  Parameters
  ----------
  gamma : float
    Discount factor applied to the successor value.
  reward_matrix : array indexed as ``reward_matrix[x,y,a]``
  lam : float
    Weight of the cost term.
  cost_matrix : array indexed as ``cost_matrix[x,y]``
  V_matrix : array indexed as ``V_matrix[x][y]``
    Current soft state values.
  num_action : int
    Number of actions per cell.

  Returns
  -------
  np.ndarray
    Object-dtype array of shape (9, 13, num_action).
  """
  # np.object / np.mat no longer exist in current NumPy releases; the
  # builtin `object` and np.asmatrix produce the same objects.
  Q_matrix=np.zeros((9,13,num_action)).astype(object)
  for x in range(9):
    for y in range(13):
      state=np.asmatrix([x,y]).T
      for a in range(num_action):
        next_state=single_expert_dynamics(state,np.asmatrix([a]).T)
        value=V_matrix[next_state.item(0)][next_state.item(1)]
        Q_matrix[x][y][a]=reward_matrix[x,y,a]-lam*cost_matrix[x,y]+gamma*value
  return Q_matrix

  
def soft_V_matrix_funciton(Q_matrix,num_action):
  """Compute soft state values V[x][y] = log(sum_a exp(Q[x][y][a])).

  Parameters
  ----------
  Q_matrix : array indexed as ``Q_matrix[x][y][a]``
    Soft action values; its first two dimensions define the grid size.
  num_action : int
    Number of actions summed over in every cell.

  Returns
  -------
  np.ndarray
    Object-dtype array of shape (rows, cols) with the log-sum-exp of the
    action values of each cell.
  """
  rows,cols=np.shape(Q_matrix)[:2]
  # np.object was removed in NumPy 1.24; the builtin `object` is equivalent.
  V_matrix=np.zeros((rows,cols)).astype(object)
  for x in range(rows):
    for y in range(cols):
      q_values=[Q_matrix[x][y][a] for a in range(num_action)]
      # Shift by the largest value before exponentiating so that large
      # action values do not overflow math.exp.
      peak=max(q_values)
      V_matrix[x][y]=peak+log(sum(exp(q-peak) for q in q_values))
  return V_matrix

def _soft_bellman_backup(gamma,reward_matrix,lam,cost_matrix,V_matrix,num_action):
  """One soft Bellman backup: return (Q, V') computed from the value table V."""
  Q_matrix=soft_Q_matrix_function(gamma,reward_matrix,lam,cost_matrix,V_matrix,num_action)
  return Q_matrix,soft_V_matrix_funciton(Q_matrix,num_action)


def _max_abs_gap(old_values,new_values):
  """Largest absolute cell-wise difference between two 9x13 value tables."""
  gap=0.0
  for x in range(9):
    for y in range(13):
      gap=max(gap,abs(old_values[x][y]-new_values[x][y]))
  return gap


def calculate_soft_policy(reward1_matrix,reward2_matrix,theta,lam,gamma,num_action,tolerance=1):
  """Run soft value iteration for both experts and return their policies.

  Both experts share the same cost table ``5.0*theta``.  Their value tables
  are backed up in lock-step until the largest change of *both* tables in one
  sweep is at most ``tolerance``; an expert that has already converged keeps
  being updated while the other one has not.

  Parameters
  ----------
  reward1_matrix, reward2_matrix : arrays indexed as ``r[x,y,a]``
    Reward tables of expert 1 and expert 2.
  theta : array indexed as ``theta[x,y]``
    Constraint indicator per cell.
  lam : float
    Weight of the cost term in the soft Q-values.
  gamma : float
    Discount factor.
  num_action : int
    Number of actions per cell.
  tolerance : float, optional
    Convergence threshold on the value change (default 1, the value that was
    previously hard-coded).

  Returns
  -------
  tuple
    ``(policy1, policy2)``, each an object-dtype array of shape
    (9, 13, num_action).
  """
  # np.object was removed in NumPy 1.24; the builtin `object` is equivalent.
  cost_matrix=np.zeros((9,13)).astype(object)
  for x in range(9):
    for y in range(13):
      cost_matrix[x,y]=theta[x,y]*5.0

  soft_V1_matrix=np.zeros((9,13)).astype(object)
  soft_V2_matrix=np.zeros((9,13)).astype(object)
  while True:
    soft_Q1_matrix,new_soft_V1_matrix=_soft_bellman_backup(
      gamma,reward1_matrix,lam,cost_matrix,soft_V1_matrix,num_action)
    soft_Q2_matrix,new_soft_V2_matrix=_soft_bellman_backup(
      gamma,reward2_matrix,lam,cost_matrix,soft_V2_matrix,num_action)
    max_value1=_max_abs_gap(soft_V1_matrix,new_soft_V1_matrix)
    max_value2=_max_abs_gap(soft_V2_matrix,new_soft_V2_matrix)
    if not (max_value1>tolerance or max_value2>tolerance):
      break
    soft_V1_matrix=new_soft_V1_matrix
    soft_V2_matrix=new_soft_V2_matrix

  policy1=soft_policy(soft_Q1_matrix,new_soft_V1_matrix,num_action)
  policy2=soft_policy(soft_Q2_matrix,new_soft_V2_matrix,num_action)
  return policy1, policy2

def KL(policy1,policy2,expert_policy1,expert_policy2):
  """Average per-state KL divergence KL(expert || policy) over both experts.

  Terms where either probability is exactly zero are skipped.

  Parameters
  ----------
  policy1, policy2 : arrays indexed as ``p[x,y,a]``
    Policies being evaluated.
  expert_policy1, expert_policy2 : arrays indexed as ``p[x,y,a]``
    Reference policies; the shape of ``expert_policy1`` defines the grid and
    the number of actions.

  Returns
  -------
  float
    Sum of ``expert*log(expert/policy)`` over all cells, actions and both
    experts, divided by the number of (cell, expert) pairs.
  """
  rows,cols,actions=np.shape(expert_policy1)[:3]
  num_states=rows*cols
  if num_states==0:
    return 0.0
  pairs=((expert_policy1,policy1),(expert_policy2,policy2))
  divergence=0.0
  for x in range(rows):
    for y in range(cols):
      for a in range(actions):
        for expert,learner in pairs:
          p=expert[x,y,a]
          q=learner[x,y,a]
          if p!=0 and q!=0:
            divergence=divergence+p*log(p/q)
  # Normalise by the cells actually summed over (the previous constant
  # 9*9*2 did not match the 9x13 grid iterated above).
  return divergence/(num_states*2)

def trial(initial_state,policy1,policy2,num_action,num_steps=30):
  """Roll out both experts' policies from a joint start state.

  Parameters
  ----------
  initial_state : array/matrix supporting ``.item(k)`` and slicing
    Joint state column ``(x1, y1, x2, y2)``.
  policy1, policy2 : arrays indexed as ``p[x][y]``
    Per-cell action distributions of expert 1 and expert 2.
  num_action : int
    Unused here; kept for interface compatibility.
  num_steps : int, optional
    Number of simulated steps (default 30, the value that was previously
    hard-coded).

  Returns
  -------
  tuple of list
    ``(trajectory1, trajectory2)``; each entry is ``[x, y, action]`` recorded
    before the move is applied.
  """
  trajectory1=[]
  trajectory2=[]
  state=initial_state
  for _ in range(num_steps):
    x1,y1=state.item(0),state.item(1)
    x2,y2=state.item(2),state.item(3)

    # Expert 1 samples first, then expert 2, so the random stream is consumed
    # in a fixed order.
    action1=choose_action(policy1[x1][y1][:])
    # np.mat was removed in NumPy 2.0; np.asmatrix builds the same 1x1 matrix.
    next_state1=single_expert_dynamics(state[0:2],np.asmatrix([action1]).T)
    action2=choose_action(policy2[x2][y2][:])
    next_state2=single_expert_dynamics(state[2:4],np.asmatrix([action2]).T)

    trajectory1.append([x1,y1,action1])
    trajectory2.append([x2,y2,action2])
    state=np.copy(np.vstack((next_state1,next_state2)))
  return trajectory1, trajectory2

def collision(x,y):
  """Tell whether cell (x, y) lies on one of the two obstacle segments.

  The segments are x in 2..6 on row y == 3 and x in 3..5 on row y == 7.
  Returns True for a cell on either segment, False otherwise.
  """
  # (first x, last x, y) of every obstacle segment.
  segments=((2,6,3),(3,5,7))
  return any(x>=first and x<=last and y==row for first,last,row in segments)

def violation_success_rate(theta):
  """Estimate how often sampled trajectories hit an obstacle cell.

  A copy of ``theta`` with cells (1, 12) and (7, 12) cleared is used to
  compute both experts' soft policies (cost weight 10.0, using the
  module-level rewards, ``gamma`` and ``num_action``).  Fifty joint rollouts
  from the module-level ``initial_state`` are simulated; each of the
  resulting 100 trajectories counts once if any of its cells is a collision.

  Returns
  -------
  tuple of float
    ``(violation_rate, 1 - violation_rate)``.
  """
  relaxed_theta=np.copy(theta)
  for exempt_cell in ((1,12),(7,12)):
    relaxed_theta[exempt_cell]=0.0
  policy1,policy2=calculate_soft_policy(reward1_matrix,reward2_matrix,relaxed_theta,10.0,gamma,num_action)

  collision_times=0.0
  for _ in range(50):
    rollout1,rollout2=trial(initial_state,policy1,policy2,num_action)
    for rollout in (rollout1,rollout2):
      # A trajectory contributes at most one violation.
      if any(collision(step[0],step[1]) for step in rollout):
        collision_times+=1.0
  violation_rate=collision_times/100
  return violation_rate, 1-collision_times/100
    

def likelihood_function(policy1,policy2,gamma):
  """Discounted log-likelihood of the stored expert demonstrations.

  For i in 0..8 the files ``expert1_trajectory_file<i>.txt`` and
  ``expert2_trajectory_file<i>.txt`` are read as flat sequences of
  (x, y, action) triples.  Step j of a trajectory adds
  ``gamma**j * log(policy[x, y, action])``; steps whose probability is
  exactly zero are skipped.

  Parameters
  ----------
  policy1, policy2 : arrays indexed as ``p[x,y,a]``
    Policies scored against expert 1 and expert 2 demonstrations.
  gamma : float
    Discount applied per step.

  Returns
  -------
  float
    Accumulated discounted log-likelihood over all files.
  """
  total=0.0
  for file_index in range(9):
    suffix=str(file_index)+".txt"
    demo1=np.loadtxt("expert1_trajectory_file"+suffix,dtype=float)
    demo2=np.loadtxt("expert2_trajectory_file"+suffix,dtype=float)
    for policy,demo in ((policy1,demo1),(policy2,demo2)):
      # Trailing values that do not complete a triple are ignored.
      for step in range(len(demo)//3):
        base=3*step
        cell_x=int(demo[base])
        cell_y=int(demo[base+1])
        chosen=int(demo[base+2])
        probability=policy[cell_x,cell_y,chosen]
        if probability!=0:
          total=total+gamma**step*log(probability)
  return total

def false_positive_negative(theta):
  """Score a constraint estimate against the eight listed obstacle cells.

  A cell is treated as flagged when ``theta[x, y] > 0.05``.  Cells (1, 12)
  and (7, 12) are ignored.  Flagged cells on the segments x in 2..6 at
  y == 3 or x in 3..5 at y == 7 count as recovered; every other flagged
  cell counts as spurious.

  Returns
  -------
  tuple of float
    ``(spurious / 109, (8 - recovered) / 8)``.
  """
  ignored_cells={(1,12),(7,12)}
  obstacle_cells={(cx,3) for cx in range(2,7)} | {(cx,7) for cx in range(3,6)}

  spurious=0.0
  recovered=0.0
  for x in range(9):
    for y in range(13):
      if not theta[x,y]>0.05:
        continue
      cell=(x,y)
      if cell in ignored_cells:
        continue
      if cell in obstacle_cells:
        recovered+=1.0
      else:
        spurious+=1.0
  return spurious/109, (8-recovered)/8

def constraint_learning(reward1_matrix,reward2_matrix,gamma,num_action):
  """Greedily add constraint cells that maximise the demonstration likelihood.

  Starting from an all-zero ``theta`` (9x13), each of 14 iterations
  tentatively sets every still-free cell to 1.0, recomputes both soft
  policies and scores them with ``likelihood_function``.  Every cell that
  attains the best likelihood is then fixed to 1.0.  After each iteration the
  likelihood, the false positive/negative figures and the violation/success
  rates are printed and recorded; at the end the five records are written to
  text files in the current directory.

  Parameters
  ----------
  reward1_matrix, reward2_matrix : arrays indexed as ``r[x,y,a]``
    Reward tables of the two experts.
  gamma : float
    Discount factor.
  num_action : int
    Number of actions per cell.

  Returns
  -------
  None
  """
  theta=np.zeros((9,13))
  iterations=14
  save_likelihood_list=np.zeros((iterations+1,1))
  false_positive_list=np.zeros((iterations+1,1))
  false_negative_list=np.zeros((iterations+1,1))
  violation_rate_list=np.zeros((iterations+1,1))
  success_rate_list=np.zeros((iterations+1,1))

  # Entry 0 of every record describes the unconstrained problem.
  old_policy1,old_policy2=calculate_soft_policy(reward1_matrix,reward2_matrix,theta,1.0,gamma,num_action)
  save_likelihood_list[0]=likelihood_function(old_policy1,old_policy2,gamma)
  violation_rate, success_rate=violation_success_rate(theta)
  violation_rate_list[0]=violation_rate
  success_rate_list[0]=success_rate
  false_positive_list[0]=0.0
  false_negative_list[0]=1.0

  for i in range(iterations):
    print('iteration={}' .format(i))
    # Cells that are already constrained keep this sentinel and are never
    # selected while a free cell exists.
    likelihood_list=np.ones((9,13))*(-1000000000000)
    for x in range(9):
      for y in range(13):
        if theta[x,y]==0.0:
          theta[x,y]=1.0
          candidate_policy1,candidate_policy2=calculate_soft_policy(reward1_matrix,reward2_matrix,theta,1.0,gamma,num_action)
          likelihood_list[x,y]=likelihood_function(candidate_policy1,candidate_policy2,gamma)
          theta[x,y]=0.0
    max_value=np.max(likelihood_list)
    save_likelihood_list[i+1]=max_value

    # Ties are all accepted, exactly as the cell-by-cell comparison did.
    theta[likelihood_list==max_value]=1.0

    # Compare the policy of the updated constraint set with the previous
    # iteration's policy.  (The divergence used to be computed from whichever
    # candidate happened to be evaluated last, which is unrelated to the
    # selected constraint.)
    new_policy1,new_policy2=calculate_soft_policy(reward1_matrix,reward2_matrix,theta,1.0,gamma,num_action)
    divergence=KL(new_policy1,new_policy2,old_policy1,old_policy2)
    print('divergence is {}' .format(divergence))
    print(theta)
    old_policy1,old_policy2=new_policy1,new_policy2

    false_positive,false_negative=false_positive_negative(theta)
    false_positive_list[i+1]=false_positive
    false_negative_list[i+1]=false_negative
    violation_rate, success_rate=violation_success_rate(theta)
    violation_rate_list[i+1]=violation_rate
    success_rate_list[i+1]=success_rate

    print('likelihood is {}' .format(max_value))
    print('violation rate is {}' .format(violation_rate))
    print('success rate is {}' .format(success_rate))
    print('false positive is', false_positive)
    print('false negative is', false_negative)

  # np.savetxt opens and closes the file itself, so no handle can leak; each
  # (iterations+1, 1) record is written as one value per line.
  outputs=(
    ("ME_likelihood_file.txt",save_likelihood_list),
    ("ME_centralized_false_positive_file.txt",false_positive_list),
    ("ME_centralized_false_negative_file.txt",false_negative_list),
    ("ME_centralized_violation_rate_file.txt",violation_rate_list),
    ("ME_centralized_success_rate_file.txt",success_rate_list),
  )
  for file_name,record in outputs:
    np.savetxt(file_name,record)



# Joint start state as a 4x1 column (x1, y1, x2, y2) for the two experts.
# np.mat was removed in NumPy 2.0; np.asmatrix builds the same np.matrix.
initial_state=np.asmatrix([1,1,7,1]).T
num_action=7
gamma=0.9
num_trials=100

reward1_matrix=np.zeros((9,13,num_action))
reward2_matrix=np.zeros((9,13,num_action))

# Each expert is rewarded only for taking action 0 in its own goal cell.
reward1_matrix[7,12,0]=50
reward2_matrix[1,12,0]=50

constraint_learning(reward1_matrix,reward2_matrix,gamma,num_action)

#theta=np.zeros((9,13))
#theta[2:7,3]=1.0
#theta[3:6,7]=1.0

#print(violation_success_rate(theta))












